from torchvision import transforms
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment as linear_assignment
import sys
import os
import random
import math
import time
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
from scipy.fftpack import fft, rfft, fftfreq, irfft, ifft, rfftfreq
from scipy import signal
import numpy as np
import importlib
from transformers import ViTFeatureExtractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import itertools
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import ViTFeatureExtractor
from torch.utils.data import DataLoader, Dataset
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch
from torch import nn


import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift
from torch.utils.data import DataLoader, TensorDataset, Dataset
from tqdm import tqdm
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.feature import local_binary_pattern
from braindecode.augmentation import FTSurrogate, SmoothTimeMask, ChannelsDropout

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange


class ImgEncoder(nn.Module):
    """ViT backbone followed by a linear projection of its first output token.

    Args:
        model_name: Hugging Face model identifier handed to
            ``ViTModel.from_pretrained`` when ``pretrained`` is true.
        pretrained: If true, load the weights for ``model_name``; otherwise
            build a ``ViTModel`` from a default ``ViTConfig``.
        trainable: Value written to ``requires_grad`` on every backbone
            parameter. The projection head is not affected by this flag.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", pretrained=True, trainable=True):
        super().__init__()
        # The file-level imports only bring in ViTFeatureExtractor, so the
        # model classes are imported here to keep this class self-contained.
        from transformers import ViTConfig, ViTModel

        # The head is registered before the backbone on purpose: parameter
        # order determines how saved optimizer state maps back onto parameters.
        self.mlp_head = nn.Sequential(
            nn.Linear(768, 256)
        )
        if pretrained:
            self.model = ViTModel.from_pretrained(model_name)
        else:
            self.model = ViTModel(config=ViTConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, pixel_values):
        """Project the backbone's first-token embedding for each input.

        Args:
            pixel_values: Batch forwarded to the backbone as ``pixel_values``.

        Returns:
            ``mlp_head`` applied to ``last_hidden_state[:, 0, :]``; the last
            dimension of the result is 256.
        """
        output = self.model(pixel_values=pixel_values)
        last_hidden_state = output.last_hidden_state
        # Index 0 along the sequence axis is the token used as the image summary.
        cls_embedding = last_hidden_state[:, 0, :]
        return self.mlp_head(cls_embedding)
    
    
    
# --- Load the image tensors and their labels from disk ----------------------
train_image_data = torch.load("/home/ubuntu/train_image_data.pt")
test_image_data = torch.load("/home/ubuntu/test_image_data.pt")
train_labels = torch.load("/home/ubuntu/train_labels.pt")
test_labels = torch.load("/home/ubuntu/test_labels.pt")

# Quick sanity report of what was loaded.
print(f"Training data shape: {train_image_data.shape}")
print(f"Training labels shape: {train_labels.shape}")
print(f"Testing data shape: {test_image_data.shape}")
print(f"Testing labels shape: {test_labels.shape}")

# --- Wrap the tensors as (image, label) datasets and batch them -------------
batch_size = 256
train_dataset, test_dataset = (
    TensorDataset(train_image_data, train_labels),
    TensorDataset(test_image_data, test_labels),
)

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of testing batches: {len(test_loader)}")






# Device the augmentation modules below are moved to.
device = "cuda"

# braindecode augmentation modules, each created with probability 0.5.
# NOTE(review): neither `Sur` nor `mask` is referenced again in the visible
# part of this file - confirm whether they are still needed.
Sur = FTSurrogate(phase_noise_magnitude=1, probability=0.5).to(device)
mask = SmoothTimeMask(mask_len_samples=50, probability=0.5).to(device)



# A little bit of cleanup, along with a change to how things are loaded.
def train(epoch, model, optimizer, loss_fn, miner, train_dataloader, device='cuda'):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        epoch: Epoch number; only used in the progress-bar caption.
        model: Network called as ``model(image)``; moved to ``device`` first.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable invoked as ``loss_fn(embeddings, labels, pairs)``.
        miner: Callable invoked as ``miner(embeddings, labels)``; its result
            is passed to ``loss_fn`` as the third argument.
        train_dataloader: Iterable yielding ``(image, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        list[float]: The loss of each batch, in iteration order.
    """
    model.to(device)
    # Set the mode explicitly rather than relying on whatever the previous
    # caller (e.g. an evaluation routine) left behind.
    model.train()
    running_loss = []

    tq = tqdm(train_dataloader)
    for image, labels in tq:
        image, labels = image.to(device), labels.to(device)

        optimizer.zero_grad()
        x_proj = model(image)
        hard_pairs = miner(x_proj, labels)
        loss = loss_fn(x_proj, labels, hard_pairs)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())
        # Caption shows the mean loss over the batches seen so far.
        tq.set_description(f'Train:[{epoch}, {np.mean(running_loss):0.3f}]')

    return running_loss

def evaluate_clustering(epoch, model, dataloader, device='cuda', num_clusters=40):
    """Embed every batch of ``dataloader`` and score the embeddings with k-means.

    Args:
        epoch: Epoch number; only used in the printed report.
        model: Network called as ``model(image)``. It is switched to eval
            mode for the pass and put back into train mode afterwards.
        dataloader: Iterable yielding ``(image, labels)`` batches.
        device: Device each image batch is moved to before the forward pass.
        num_clusters: Number of clusters given to ``K_means`` (default 40,
            the value that was previously fixed inside the function).

    Returns:
        Whatever ``K_means.transform`` returns for the collected embeddings
        and labels; the same value is printed.
    """
    model.eval()
    image_featvec_proj, labels_array = [], []

    try:
        with torch.no_grad():
            for image, labels in tqdm(dataloader):
                x_proj = model(image.to(device))
                image_featvec_proj.append(x_proj.cpu().numpy())
                # Labels are only needed as NumPy data, so they skip the
                # device round-trip.
                labels_array.append(labels.cpu().numpy())

        image_featvec_proj = np.concatenate(image_featvec_proj, axis=0)
        labels_array = np.concatenate(labels_array, axis=0)

        # NOTE(review): `K_means` is not defined or imported in the part of
        # the file visible here - confirm where it comes from.
        k_means = K_means(n_clusters=num_clusters)
        clustering_acc_proj = k_means.transform(image_featvec_proj, labels_array)
    finally:
        # Always hand the model back in train mode, even if scoring failed.
        model.train()

    print(f"[Epoch: {epoch}, Train KMeans score Proj: {clustering_acc_proj}]")
    return clustering_acc_proj


def validation(epoch, model, optimizer, loss_fn, miner, val_dataloader, device='cuda', num_clusters=40):
    """Compute validation losses and a k-means score over ``val_dataloader``.

    Args:
        epoch: Epoch number; only used in the progress caption and report.
        model: Network called as ``model(image)``. It is switched to eval
            mode for the pass and put back into train mode afterwards.
        optimizer: Unused; kept so existing call sites keep working.
        loss_fn: Callable invoked as ``loss_fn(embeddings, labels, pairs)``.
        miner: Callable invoked as ``miner(embeddings, labels)``.
        val_dataloader: Iterable yielding ``(image, labels)`` batches.
        device: Device each batch is moved to. Defaults to ``'cuda'``, the
            value of the module-level ``device`` this function used to read.
        num_clusters: Number of clusters given to ``K_means`` (default 40,
            the value that was previously fixed inside the function).

    Returns:
        tuple: ``(running_loss, clustering_acc_proj)`` where ``running_loss``
        is the list of per-batch losses (NumPy values) and
        ``clustering_acc_proj`` is the result of ``K_means.transform``.
    """
    running_loss = []
    # Batches are collected in lists and joined once after the loop, instead
    # of re-copying the growing arrays on every iteration.
    feat_batches, label_batches = [], []

    tq = tqdm(val_dataloader)
    model.eval()
    try:
        with torch.no_grad():
            for image, labels in tq:
                image, labels = image.to(device), labels.to(device)
                x_proj = model(image)
                hard_pairs = miner(x_proj, labels)
                loss = loss_fn(x_proj, labels, hard_pairs)
                running_loss.append(loss.detach().cpu().numpy())

                tq.set_description('Val:[{}, {:0.3f}]'.format(epoch, np.mean(running_loss)))

                feat_batches.append(x_proj.cpu().numpy())
                label_batches.append(labels.cpu().numpy())

        # An empty loader yields empty arrays, as the accumulators did before.
        image_featvec_proj = np.concatenate(feat_batches, axis=0) if feat_batches else np.array([])
        labels_array = np.concatenate(label_batches, axis=0) if label_batches else np.array([])

        # NOTE(review): `K_means` is not defined or imported in the part of
        # the file visible here - confirm where it comes from.
        k_means = K_means(n_clusters=num_clusters)
        clustering_acc_proj = k_means.transform(image_featvec_proj, labels_array)
    finally:
        # Always hand the model back in train mode, even if scoring failed.
        model.train()

    print("[Epoch: {}, Val KMeans score Proj: {}]".format(epoch, clustering_acc_proj))
    return running_loss, clustering_acc_proj



    

# ## hyperparameters
EPOCHS = 8000
device = "cuda"

# Encoder wrapped in DataParallel and placed on the target device.
model = torch.nn.DataParallel(ImgEncoder()).to(device)

# Adam over every model parameter, with a cosine schedule spanning the run.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)




# Epoch the training loop starts from; advanced below when resuming.
START_EPOCH = 0
# When true, restore model/optimizer/scheduler state from `ckpt_path`.
pre = True
if pre:
    # NOTE(review): this loads 'eegfeat_all.pth' from an absolute directory,
    # while the training loop saves 'bestckpt/imagefeat_all.pth' relative to
    # the working directory - confirm this is the intended checkpoint.
    ckpt_path  = '/home/ubuntu/bestckpt/eegfeat_all.pth'
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # The checkpoint stores the last finished epoch, so resume at the next one.
    START_EPOCH = checkpoint['epoch'] + 1

# exist_ok=True: the output directory usually already exists on a rerun or a
# resume, and a plain makedirs would abort the script with FileExistsError.
os.makedirs('bestckpt/', exist_ok=True)



# Miner and loss objects handed to train() and validation() in the loop below.
# NOTE(review): `miners` and `losses` are not imported in the visible part of
# this file - presumably `from pytorch_metric_learning import miners, losses`;
# confirm and add the import, otherwise these lines raise NameError.
miner   = miners.MultiSimilarityMiner()
loss_fn = losses.TripletMarginLoss()




# Best validation clustering score seen so far and the epoch that produced it.
best_val_acc = 0.0
best_val_epoch = 0
EPOCHS = 8000

for epoch in range(START_EPOCH, EPOCHS):
    # One training pass, one validation pass, then advance the LR schedule.
    running_train_loss = train(epoch, model, optimizer, loss_fn, miner, train_loader)
    running_val_loss, val_acc = validation(epoch, model, optimizer, loss_fn, miner, test_loader)
    scheduler.step()

    # Only epochs that strictly improve the validation score are checkpointed.
    if not best_val_acc < val_acc:
        continue

    best_val_acc, best_val_epoch = val_acc, epoch
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        'bestckpt/imagefeat_{}.pth'.format('all'),
    )


